import numpy as np
import os

import time
import utils.env_utils,utils.misc
from utils.reprod_setup import device, set_seed
import src.plotter as plotter
from src.logger import SummaryWriterCallback
from stable_baselines3.common.callbacks import EvalCallback, CallbackList, StopTrainingOnRewardThreshold
from stable_baselines3 import PPO

from stable_baselines3.common.noise import NormalActionNoise, VectorizedActionNoise
from src.custom_callbacks import CustomEval, CustomEarlyStop
from src.policies import CustomSAC, CustomDDPG
import torch
import shutil
from stable_baselines3 import DDPG

class PolicyTrainer():
    """Train, evaluate and persist one policy on vectorized environments.

    An instance owns the training environments (`envs`), a single
    evaluation environment (`eval_env`) and the most recently built model
    (`model`).  `train` is the entry point: it prepares the environments,
    runs learning, scores the result with `eval_policy` and, when the score
    is acceptable, stores the policy with `save`.
    """

    def __init__(self, config_params=None, init_set=None, term_set=None, term_sampler = None,policy_type='region_policy'):
        """Store the configuration; environments and model are built lazily.

        Args:
            config_params: mapping with the global settings.  It is indexed
                with `policy_type` right away, so passing None raises
                TypeError.
            init_set: forwarded as `init_set` when environments are created.
            term_set: forwarded as `term_set` when environments are created.
            term_sampler: stored on the instance (train() receives its own).
            policy_type: key of the model-specific section in
                `config_params`.
        """
        self.device = device
        self.init_set = init_set
        self.term_set = term_set
        self.config_params = config_params
        self.envs = None
        self.eval_env = None
        self.term_sampler = term_sampler
        self.policy_type = policy_type
        # Defined up front so save()/eval_policy() fail predictably when
        # they are called before train().
        self.model = None
        self.best_model_path = None
        self.model_params = self.config_params[self.policy_type]

    def train(self,init_sampler,eval_func, term_sampler,option_guide,region_switch_point,tblog_prefix='',seed=1337):
        """Train a CustomSAC policy, evaluate it and save it if it passes.

        Args:
            init_sampler: installed as the environments' `init_set` attribute
                when already existing environments are reused.
                NOTE(review): freshly created environments receive the
                constructor's `init_set` instead - confirm this is intended.
            eval_func: installed as the environments' `term_set` attribute on
                reuse (same caveat as `init_sampler`).
            term_sampler: forwarded to the environments.
            option_guide: forwarded to the environments.
            region_switch_point: forwarded to the environments.
            tblog_prefix: sub-directory used for the tensorboard logs and for
                the best-model checkpoint written during evaluation.
            seed: seed handed to the environments and to the model.

        Returns:
            Tuple `(model_path, cost, status, eval_ep_len, training_steps)`:
            `model_path` is None unless `status` is 'okay'; `eval_ep_len` is
            `[mean, std]` of the episode lengths of the last periodic
            evaluation, or `[max_ep_len, 0]` when none was recorded.
        """
        if self.model_params.get('net_arch') is None:
            self.model_params['net_arch'] = [128,64]

        if self.envs is None:
            self._create_envs(term_sampler, option_guide, region_switch_point, seed)
        else:
            print("Loading new funcs into envs")
            for vec_env in (self.envs, self.eval_env):
                self._reload_env_funcs(vec_env, init_sampler, eval_func,
                                       term_sampler, option_guide, region_switch_point)

        self.model = self._build_model(self._tensorboard_log_dir(tblog_prefix), seed)

        print("Training...")
        log_name = self.envs.env_method("make_log_dir")[0]
        callback_on_best = CustomEarlyStop(reward_threshold=200, max_no_improvement_evals=5, min_evals=2, verbose = 0)
        self.best_model_path = os.path.join('./eval/best/',tblog_prefix)
        # NOTE: a CustomEval callback was used here in an earlier revision;
        # the stock EvalCallback is the active one.
        eval_callback = EvalCallback(eval_env=self.eval_env,
                                     n_eval_episodes=20,
                                     callback_on_new_best=callback_on_best,
                                     eval_freq=self.model_params['eval_freq'],
                                     log_path='./eval/logs/{}_{}'.format(self.config_params['robot']['name'], self.config_params['env_name']),
                                     best_model_save_path=self.best_model_path,
                                     verbose=1)

        callbacks = CallbackList([eval_callback, SummaryWriterCallback()])
        self.model.learn(total_timesteps=self.model_params['train_timesteps'],
                        log_interval=10,
                        tb_log_name='{}'.format(log_name),
                        reset_num_timesteps=False,
                        callback=callbacks)
        print("Eval...")
        cost, status = self.eval_policy()

        # save() prefers the best checkpoint, so it must run before cleanup.
        model_path = self.save() if status == 'okay' else None

        self._remove_best_checkpoint()
        self._close_envs()
        torch.cuda.empty_cache()

        training_steps = self.model.num_timesteps
        recorded_lengths = eval_callback.evaluations_length
        if recorded_lengths:
            last_lengths = recorded_lengths[-1]
            eval_ep_len = [np.mean(last_lengths), np.std(last_lengths)]
        else:
            # No periodic evaluation ran; fall back to the episode cap.
            eval_ep_len = [self.model_params['max_ep_len'], 0]

        return model_path, cost, status, eval_ep_len, training_steps

    def _create_envs(self, term_sampler, option_guide, region_switch_point, seed):
        """Create training and evaluation environments plus action noise."""
        shared_kwargs = dict(
            seed=seed,
            init_set=self.init_set,
            term_set=self.term_set,
            term_sampler=term_sampler,
            option_guide=option_guide,
            region_switch_point=region_switch_point,
            demo_mode=self.config_params['debug'],
            robot_config=self.config_params['robot'],
            env_path=os.path.join(self.config_params['env_path'], self.config_params['env_name']+".stl"),
            max_ep_len=self.model_params['max_ep_len'],
        )
        self.envs = utils.env_utils.make_subproc_envs(num=self.model_params['train_envs'],
                                                      gui=self.config_params['trainer_gui'],
                                                      **shared_kwargs)

        self.n_actions = self.envs.get_attr('action_space')[0].shape[-1]
        self.action_noise = NormalActionNoise(mean=np.zeros(self.n_actions), sigma=0.1 * np.ones(self.n_actions))
        # Not passed to the model (which receives `action_noise`); kept as
        # an attribute because the original class exposed it.
        self.noise = VectorizedActionNoise(self.action_noise, self.model_params['train_envs'])

        self.eval_env = utils.env_utils.make_subproc_envs(num=1,
                                                          gui=False,
                                                          monitor=True,
                                                          **shared_kwargs)

    @staticmethod
    def _reload_env_funcs(vec_env, init_sampler, eval_func, term_sampler, option_guide, region_switch_point):
        """Install new task functions on an existing vec env and reset it."""
        vec_env.set_attr('init_set', init_sampler)
        vec_env.set_attr('term_set', eval_func)
        vec_env.set_attr('term_sampler', term_sampler)
        vec_env.set_attr('option_guide',option_guide)
        vec_env.set_attr('region_switch_point',region_switch_point)
        # NOTE(review): eval_policy() takes create_env_info_dict from
        # utils.misc instead of utils.env_utils - confirm both exist.
        vec_env.set_attr('info',utils.env_utils.create_env_info_dict())
        vec_env.reset()

    @staticmethod
    def _tensorboard_log_dir(tblog_prefix):
        """Return the timestamped tensorboard directory for this run."""
        stamp = time.strftime("%d%y%h_%H%M")
        if tblog_prefix == '':
            return "./logs/sac/{}/".format(stamp)
        return "./logs/sac/{}/{}/".format(tblog_prefix, stamp)

    def _build_model(self, tensorboard_log_dir, seed):
        """Instantiate the CustomSAC model on the training environments."""
        # NOTE: CustomDDPG and PPO variants were tried here in earlier
        # revisions; SAC is the active algorithm.
        return CustomSAC(policy="MlpPolicy",
                         env=self.envs,
                         verbose=1,
                         action_noise=self.action_noise,
                         learning_starts=int(self.model_params['warm_starts']*self.model_params['train_envs']),
                         ent_coef='auto',
                         optimize_memory_usage=True,
                         gradient_steps=self.model_params['update_batch'],
                         buffer_size=self.model_params['buffer_size'],
                         learning_rate=0.003,
                         use_sde=True,
                         use_sde_at_warmup=False,
                         train_freq=(self.model_params['train_freq'],"step"),
                         tensorboard_log=tensorboard_log_dir,
                         batch_size=self.model_params['batch_size'],
                         seed=seed,
                         policy_kwargs=dict(net_arch=self.model_params['net_arch'],normalize_images=False))

    def _remove_best_checkpoint(self):
        """Best-effort removal of the checkpoint written by EvalCallback."""
        checkpoint = os.path.join(self.best_model_path, "best_model.zip")
        try:
            # os.remove instead of a shell `rm`: no quoting or injection
            # issues with tblog_prefix, and it works without a POSIX shell.
            os.remove(checkpoint)
        except FileNotFoundError:
            pass
        except OSError as err:
            # Cleanup must never abort a finished training run.
            print("Could not remove {}: {}".format(checkpoint, err))

    def _close_envs(self):
        """Close training and evaluation environments and forget them."""
        # eval_env is closed as well; the next train() call creates a fresh
        # one and would otherwise orphan the old one.
        for attr_name in ('envs', 'eval_env'):
            vec_env = getattr(self, attr_name, None)
            if vec_env is not None:
                vec_env.close()
            setattr(self, attr_name, None)

    def eval_policy(self):
        """Roll out the current model for a fixed wall-clock time and score it.

        The model acts on the training environments until
        `model_params['eval_time_limit']` seconds have elapsed.  The info
        dicts of the last step are then read for the keys 'done_count',
        'total_dones' and 'action_count'.

        Returns:
            Tuple `(cost, status)`.  `status` is "okay" when at least one
            environment has `done_count > 0` and the number of such
            environments reaches `int(eval_threshold * n_envs)`; `cost` is
            then `cost_heuristic` of the per-environment action counts.
            Otherwise `status` is "timeout" and `cost` is infinite.
        """
        obs = self.envs.reset()

        # set_attr takes the bare attribute name; the former 'self.'-prefixed
        # names created new attributes and left the counters untouched, so
        # results from training leaked into the evaluation statistics.
        self.envs.set_attr('info',utils.misc.create_env_info_dict())
        self.envs.set_attr('log_arr',[])
        self.envs.set_attr('total_timesteps',0)

        # Defined before the loop so a non-positive time limit yields a
        # clean "timeout" instead of a NameError.
        info = []
        dones = []
        t_end = time.time()+self.model_params['eval_time_limit']
        while time.time() < t_end:
            action, _states = self.model.predict(obs)
            obs, rewards, dones, info = self.envs.step(action)

        if self.config_params['visualize_policy']:
            policy_path = self.envs.env_method("save_log")
            plotter.make_plots(policy_path[0])

        done_ratio = [i['done_count']/i['total_dones'] if i['total_dones']>0 else float('inf') for i in info]
        print("Successful resets / total resets:{}".format(done_ratio))

        succ = np.where([i['done_count']>0 for i in info])[0]
        # Environments without a success get an infinite placeholder;
        # cost_heuristic leaves those out of the average.
        actions = [i['action_count']/i['done_count'] if i['done_count']>0 else float('inf') for i in info]

        if len(succ) > 0 and len(succ) >= int(self.config_params['eval_threshold'] * len(dones)):
            status = "okay"
            cost = self.cost_heuristic(actions)
        else:
            status = "timeout"
            cost = np.inf
        print(succ)
        print(cost, status)
        return cost, status

    def save(self):
        """Write the policy into `config_params['policy_folder']`.

        The best checkpoint stored under `best_model_path` is copied when it
        exists; otherwise the current model is saved directly.  The target
        folder is created if necessary.

        Returns:
            Path of the written file (timestamped, '.pt' extension).
        """
        policy_folder = self.config_params["policy_folder"]
        if policy_folder:
            # shutil.copyfile does not create missing parent directories.
            os.makedirs(policy_folder, exist_ok=True)
        model_path = os.path.join(policy_folder,
                                  time.strftime('%d%y%h_%H%M%S')+'.pt')

        # best_model_path is only known after train() has started.
        best_dir = getattr(self, 'best_model_path', None)
        best_model = None if best_dir is None else os.path.join(best_dir, "best_model.zip")
        if best_model is not None and os.path.isfile(best_model):
            shutil.copyfile(best_model,model_path)
        else:
            self.model.save(model_path)

        return model_path

    def cost_heuristic(self, actions):
        """Return the mean action count over successful runs.

        Args:
            actions: per-environment action counts; unsuccessful runs are
                marked with a non-finite value (inf) and are ignored.

        Returns:
            Mean of the finite entries, or infinity when there are none.
        """
        per_run = np.asarray(actions, dtype=float)
        successful = per_run[np.isfinite(per_run)]
        if successful.size == 0:
            return np.inf
        return successful.mean()

